Solutions/ESET Protect Platform/Data Connectors/integration/utils.py

import asyncio
import logging
import typing as t
import urllib.parse
from datetime import datetime, timedelta, timezone

from aiohttp import ClientSession
from aiohttp.client_exceptions import ClientResponseError
from cryptography.fernet import Fernet, InvalidToken
from pydantic import ValidationError

from azure.core.exceptions import HttpResponseError, ServiceRequestError
from azure.data.tables import TableServiceClient
from azure.identity.aio import DefaultAzureCredential
from azure.monitor.ingestion.aio import LogsIngestionClient

from integration.exceptions import (
    AuthenticationException,
    InvalidCredentialsException,
    MissingCredentialsException,
    TokenRefreshException,
)
from integration.models import Config, EnvVariables, TokenStorage
from integration.models_detections import Detection, Detections


class RequestSender:
    def __init__(self, config: Config, env_vars: EnvVariables):
        self.config = config
        self.env_vars = env_vars

    async def send_request(
        self,
        send_request_fun: (
            t.Callable[
                [ClientSession, dict[str, t.Any] | None, str, str | None, int, str],
                t.Coroutine[t.Any, t.Any, dict[str, str | int] | t.Any],
            ]
            | t.Callable[
                [ClientSession, dict[str, t.Any] | None, str | None],
                t.Coroutine[t.Any, t.Any, dict[str, str | int] | t.Any],
            ]
        ),
        session: ClientSession,
        headers: dict[str, t.Any] | None = None,
        *data: t.Any,
    ) -> t.Dict[str, str | int] | None:
        retries = 0
        while retries < self.config.max_retries:
            try:
                return await send_request_fun(session, headers, *data)
            except ClientResponseError as e:
                if e.headers:
                    logging.info(f"Request-ID: {e.headers.get('Request-ID')}")
                if e.status in [400, 401, 403]:
                    raise AuthenticationException(status=e.status, message=e.message)
                if e.status == 404:
                    logging.info("Endpoint not found.")
                    return None
                retries += 1
                logging.error(
                    f"Exception: {e.status} {e.message}. Request failed. "
                    f"Request retry attempt: {retries}/{self.config.max_retries}"
                )
                await asyncio.sleep(self.config.retry_delay)
        return None

    async def send_request_post(
        self, session: ClientSession, headers: dict[str, t.Any] | None, grant_type: str | None
    ) -> t.Dict[str, str | int] | t.Any:
        logging.info("Sending token request")
        async with session.post(
            url=f"{self.env_vars.oauth_url}/oauth/token",
            headers=headers,
            data=urllib.parse.quote(f"grant_type={grant_type}", safe="=&/"),
            timeout=self.config.requests_timeout,
        ) as response:
            return await response.json()

    async def send_request_get(
        self,
        session: ClientSession,
        headers: dict[str, t.Any] | None,
        last_detection_time: str,
        next_page_token: str | None,
        page_size: int,
        data_endpoint: str,
    ) -> t.Dict[str, str | int] | t.Any:
        logging.info("Sending service request")
        async with session.get(
            self.env_vars.detections_url + data_endpoint,
            headers=headers,
            params=self._prepare_get_request_params(last_detection_time, next_page_token, page_size),
        ) as response:
            return await response.json()

    def _prepare_get_request_params(
        self, last_detection_time: str, next_page_token: str | None, page_size: int = 100
    ) -> dict[str, t.Any]:
        params = {"pageSize": page_size}
        if next_page_token not in ["", None]:
            params["pageToken"] = next_page_token  # type: ignore[assignment]
        if last_detection_time:
            params["startTime"] = last_detection_time  # type: ignore[assignment]
        return params


class TokenProvider:
    def __init__(self, token: TokenStorage, requests_sender: RequestSender, env_vars: EnvVariables, buffer: int):
        self.token = token
        self.requests_sender = requests_sender
        self.env_vars = env_vars
        self.buffer = buffer
        self.fernet = Fernet(self.env_vars.key_base64.encode("utf-8"))
        self.storage_table_name = "TokenParams"
        self.storage_table_handler = StorageTableHandler(self.env_vars.conn_str, self.storage_table_name)
        self.storage_table_handler.set_entity()
        self.get_token_params_from_storage()

    def get_token_params_from_storage(self) -> None:
        if not self.storage_table_handler.entities:
            return None
        for token_param in self.token.to_dict().keys():
            value = self.storage_table_handler.entities.get(token_param)
            if isinstance(value, bytes):
                try:
                    value = self.fernet.decrypt(value).decode("utf-8")
                except InvalidToken:
                    logging.warning("Issue with decrypt: Invalid Token")
                    value = ""
            setattr(self.token, token_param, value)

    async def get_token(self, session: ClientSession) -> None:
        if not self.token.access_token or datetime.now(timezone.utc) > self.token.expiration_time:  # type: ignore
            logging.info("Getting token")
            if not self.token.access_token and (not self.env_vars.username or not self.env_vars.password):
                raise MissingCredentialsException()
            grant_type = (
                f"refresh_token&refresh_token={self.token.refresh_token}"
                if self.token.access_token
                else f"password&username={self.env_vars.username}&password={self.env_vars.password}"
            )
            try:
                response = await self.requests_sender.send_request(
                    self.requests_sender.send_request_post,
                    session,
                    {"Content-type": "application/x-www-form-urlencoded", "3rd-integration": "MS-Sentinel"},
                    grant_type,
                )
            except AuthenticationException as e:
                if not self.token.access_token:
                    raise InvalidCredentialsException(e)
                else:
                    self.storage_table_handler.input_entity({k: "" for k in self.token.to_dict()})  # type: ignore[call-arg]
                    raise TokenRefreshException(e)
            if response:
                self.set_token_params_locally_and_in_storage(response)
                logging.info("Token obtained successfully")

    def set_token_params_locally_and_in_storage(self, response: t.Dict[str, str | int]) -> None:
        self.token.access_token = t.cast(str, response["access_token"])
        self.token.refresh_token = t.cast(str, response["refresh_token"])
        self.token.expiration_time = datetime.now(timezone.utc) + timedelta(
            seconds=int(response["expires_in"]) - self.buffer
        )
        self.storage_table_handler.input_entity(
            {
                k: self.fernet.encrypt(v.encode("utf-8")) if type(v) is str else v
                for k, v in self.token.to_dict().items()
            }
        )  # type: ignore[call-arg]


class TransformerDetections:
    def __init__(self, env_vars: EnvVariables) -> None:
        self.env_vars = env_vars

    async def send_integration_detections(
        self, detections: dict[str, t.Any] | None, last_detection: str | None
    ) -> tuple[str | None, bool]:
        validated_detections = self._validate_detections_data(detections)
        if not validated_detections:
            return last_detection, False
        return await self._send_data_to_log_analytics_workspace(validated_detections, last_detection)

    def _validate_detections_data(self, response_data: dict[str, t.Any] | None) -> list[dict[str, t.Any]] | None:
        if not response_data:
            logging.info("No new detections")
            return None
        response_data["detections"] = (
            response_data.pop("detectionGroups")
            if "detectionGroups" in response_data
            else response_data["detections"]
        )
        try:
            validated_data = Detections.model_validate(response_data)
            self._update_time_generated(validated_data.detections)
            return validated_data.model_dump().get("detections")
        except ValidationError as e:
            logging.error(e)
            validated_detections = []
            for detection in response_data.get("detections"):  # type: ignore
                try:
                    validated_detections.append(Detection.model_validate(detection))
                except ValidationError as e:
                    logging.error(e)
            self._update_time_generated(validated_detections)
            return [d.model_dump() for d in validated_detections]

    async def _send_data_to_log_analytics_workspace(
        self, validated_data: t.List[dict[str, t.Any]], last_detection: str | None, successful_data_upload: bool = False
    ) -> tuple[str | None, bool]:
        credential = DefaultAzureCredential()  # Env vars: AZURE_CLIENT_ID, AZURE_CLIENT_SECRET, AZURE_TENANT_ID
        client = LogsIngestionClient(endpoint=self.env_vars.endpoint_uri, credential=credential, logging_enable=True)
        async with client:
            try:
                await client.upload(
                    rule_id=self.env_vars.dcr_immutableid,
                    stream_name=self.env_vars.stream_name,
                    logs=validated_data,  # type: ignore[arg-type]
                )
                last_detection = max(validated_data, key=lambda detection: detection.get("occurTime")).get("occurTime")  # type: ignore
                successful_data_upload = True
            except ServiceRequestError as e:
                logging.error(f"Authentication to Azure service failed: {e}")
            except HttpResponseError as e:
                logging.error(f"Upload failed: {e}")
        await credential.close()
        return last_detection, successful_data_upload

    def _update_time_generated(self, validated_data: t.List[Detection]) -> None:
        utc_now = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ")
        for data in validated_data:
            data.TimeGenerated = utc_now


class StorageTableHandler:
    def __init__(self, env_conn_str: str, table_name_keys: str) -> None:
        self.conn_str = env_conn_str
        self.table_name_keys = table_name_keys
        self.entities = None
        self.table_client = None

    def with_table_client(func: t.Callable[[t.Any, t.Any], t.Any]) -> t.Callable[[t.Any], t.Any]:  # type: ignore
        def wrapper(storage_table_handler_instance, *args, **kwargs):  # type: ignore[no-untyped-def]
            try:
                with TableServiceClient.from_connection_string(
                    conn_str=storage_table_handler_instance.conn_str
                ) as table_service_client:
                    storage_table_handler_instance.table_client = table_service_client.create_table_if_not_exists(
                        table_name=storage_table_handler_instance.table_name_keys
                    )
                    return func(storage_table_handler_instance, *args, **kwargs)
            except ValueError as e:
                raise ValueError(f"Connection string WEBSITE_CONTENTAZUREFILECONNECTIONSTRING value error: {e}")

        return wrapper

    @with_table_client  # type: ignore
    def set_entity(self) -> None:
        if self.table_client:
            self.entities = next(self.table_client.query_entities(""), None)
        return None

    @with_table_client
    def input_entity(self, new_entity: dict[str, t.Any]) -> None:
        entity = {
            "PartitionKey": self.table_name_keys,
            "RowKey": self.table_name_keys,
            "TimeGenerated": datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ"),
        } | new_entity
        try:
            if self.table_client:
                (
                    self.table_client.update_entity(entity=entity)
                    if self.entities
                    else self.table_client.create_entity(entity=entity)
                )
                self.entities = next(self.table_client.query_entities(""), None)
                logging.info(f"Entity: {self.table_name_keys} updated")
        except Exception as e:
            print("Exception occurred:", e)


class LastDetectionTimeHandler:
    def __init__(self, storage_table_conn_str: str, env_last_occur_time: str, data_source: str) -> None:
        self.storage_table_name = f"LastDetectionTime{data_source}"
        self.storage_table_handler = StorageTableHandler(storage_table_conn_str, self.storage_table_name)
        self.storage_table_handler.set_entity()
        self.last_detection_time = self.get_last_occur_time(env_last_occur_time)

    def get_last_occur_time(self, env_last_occur_time: str) -> t.Any:
        if self.storage_table_handler.entities:
            return self.storage_table_handler.entities.get(self.storage_table_name)
        return env_last_occur_time

    def get_entity_schema(self, cur_last_detection_time: str) -> dict[str, t.Any]:
        return {
            self.storage_table_name: (
                datetime.strptime(
                    self.transform_date_with_miliseconds_to_second(cur_last_detection_time), "%Y-%m-%dT%H:%M:%SZ"
                )
                + timedelta(seconds=1)
            ).isoformat()
            + "Z"
        }

    def transform_date_with_miliseconds_to_second(self, cur_last_detection_time: str) -> str:
        return cur_last_detection_time if len(cur_last_detection_time) == 20 else cur_last_detection_time[:-5] + "Z"
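

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; not part of the shipped module).
# It assumes already-constructed Config, EnvVariables and TokenStorage
# instances, whose construction lives elsewhere in the connector. The token
# buffer, start time, table suffix, endpoint path and Authorization header
# format below are placeholder assumptions, not values taken from this file.
async def _example_poll_once(config: Config, env_vars: EnvVariables, token_storage: TokenStorage) -> None:
    sender = RequestSender(config, env_vars)
    token_provider = TokenProvider(token_storage, sender, env_vars, buffer=60)  # 60 s expiry buffer: assumption
    transformer = TransformerDetections(env_vars)
    time_handler = LastDetectionTimeHandler(
        env_vars.conn_str, "2024-01-01T00:00:00Z", "Detections"  # placeholder start time and table suffix
    )

    async with ClientSession() as session:
        # Obtain (or refresh) the OAuth token and build the auth header (header format is an assumption).
        await token_provider.get_token(session)
        headers = {"Authorization": f"Bearer {token_provider.token.access_token}"}

        # Fetch one page of detections; "/v1/detections" is a hypothetical endpoint suffix.
        detections = await sender.send_request(
            sender.send_request_get,
            session,
            headers,
            time_handler.last_detection_time,
            None,  # next_page_token
            100,  # page_size
            "/v1/detections",
        )

        # Validate, upload to the Log Analytics workspace, and persist the new time watermark.
        last_detection, uploaded = await transformer.send_integration_detections(
            detections, time_handler.last_detection_time
        )
        if uploaded and last_detection:
            time_handler.storage_table_handler.input_entity(time_handler.get_entity_schema(last_detection))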